{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to Split Neural Network (SplitNN)\n",
    "\n",
    "Traditionally, PySyft has been used to facilitate federated learning. However, we can also leverage the tools included in this framework to implement distributed neural networks. \n",
    "\n",
    "### What is a SplitNN?\n",
    "\n",
    "The training of a neural network (NN) is 'split' accross one or more hosts. Each model segment is a self contained NN that feeds into the segment in front. In this example Alice has unlabeled training data and the bottom of the network whereas Bob has the corresponding labels and the top of the network. The image below shows this training process where Bob has all the labels and there are multiple alices with <i>X</i> data [[1](https://arxiv.org/abs/1810.06060)]. Once $Alice_1$ has trained she sends a copy of her trained bottom model to the next Alice. This continues until $Alice_n$ has trained.\n",
    "\n",
    "<img src=\"images/P2P-DL.png\" width=\"50%\" alt=\"dominating_sets_example2\">\n",
    "\n",
    "In this case, both parties can train the model without knowing each others data or full details of the model. When Alice is finished training, she passes it to the next person with data.\n",
    "\n",
    "### Why use a SplitNN?\n",
    "\n",
    "The SplitNN has been shown to provide a dramatic reduction to the computational burden of training while maintaining higher accuracies when training over large number of clients [[2](https://arxiv.org/abs/1812.00564)]. In the figure below, the Blue line denotes distributed deep learning using SplitNN, red line indicate federated learning (FL) and green line indicates Large Batch Stochastic Gradient Descent (LBSGD).\n",
    "\n",
    "<img src=\"images/AccuracyvsFlops.png\" width=\"100%\">\n",
    "\n",
    "<img src=\"images/computation.png\" width=\"60%\">\n",
    "<img src=\"images/bandwidth.png\" width=\"60%\">\n",
    " \n",
    "Table 1 shows computational resources consumed when training CIFAR 10 over VGG. Theses are a fraction of the resources of FL and LBSGD. Table 2 shows the bandwith usage when training CIFAR 100 over ResNet. Federated learning is less bandwidth intensive with fewer than 100 clients. However, the SplitNN outperforms other approaches as the number of clients grow[ [2](https://arxiv.org/abs/1812.00564)].\n",
    "\n",
    "\n",
    "### Advantages\n",
    "\n",
    "- The accuracy should be identical to a non-split version of the same model, trained locally. \n",
    "- the model is distributed, meaning all segment holders must consent in order to aggregate the model at the end of training.\n",
    "- The scalability of this approach, in terms of both network and computational resources, could make this an a valid alternative to FL and LBSGD, particularly on low power devices.\n",
    "- This could be an effective mechanism for both horizontal and vertical data distributions.\n",
    "- As computational cost is already quite low, the cost of applying homomorphic encryption is also minimised.\n",
    "- Only activation signal gradients are sent/ recieved, meaning that malicious actors cannot use gradients of model parameters to reverse engineer the original values.\n",
    "\n",
    "### Constraints\n",
    "\n",
    "- A new technique with little surroundung literature, a large amount of comparison and evaluation is still to be performed.\n",
    "- This approach requires all hosts to remain online during the entire learning process (less fesible for hand-held devices).\n",
    "- Not as established in privacy-preserving toolkits as FL and LBSGD.\n",
    "- Activation signals and their corresponding gradients still have the capacity to leak information, however this is yet to be fully addressed in the literature.\n",
    "\n",
    "### Tutorial \n",
    "\n",
    "This tutorial demonstrates a basic example of SplitNN which\n",
    "\n",
    "- Has two paticipants: Alice and Bob.\n",
    "    - Bob has <i>labels</i>\n",
    "    - Alice has <i>X</i> values\n",
    "- Has two model segments.\n",
    "    - Alice has the bottom half\n",
    "    - Bob has the top half\n",
    "- Trains on the MNIST dataset.\n",
    "\n",
    "\n",
    "Authors:\n",
    "- Adam J Hall - Twitter: [@AJH4LL](https://twitter.com/AJH4LL) · GitHub:  [@H4LL](https://github.com/H4LL)\n",
    "- Théo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel) · GitHub: [@LaRiffle](https://github.com/LaRiffle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import datasets, transforms\n",
    "from torch import nn, optim\n",
    "import syft as sy\n",
    "hook = sy.TorchHook(torch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data preprocessing\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                              transforms.Normalize((0.5,), (0.5,)),\n",
    "                              ])\n",
    "trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "\n",
    "# Define our model segments\n",
    "\n",
    "input_size = 784\n",
    "hidden_sizes = [128, 640]\n",
    "output_size = 10\n",
    "\n",
    "models = [\n",
    "    nn.Sequential(\n",
    "                nn.Linear(input_size, hidden_sizes[0]),\n",
    "                nn.ReLU(),\n",
    "                nn.Linear(hidden_sizes[0], hidden_sizes[1]),\n",
    "                nn.ReLU(),\n",
    "    ),\n",
    "    nn.Sequential(\n",
    "                nn.Linear(hidden_sizes[1], output_size),\n",
    "                nn.LogSoftmax(dim=1)\n",
    "    )\n",
    "]\n",
    "\n",
    "# Create optimisers for each segment and link to their segment\n",
    "optimizers = [\n",
    "    optim.SGD(model.parameters(), lr=0.03,)\n",
    "    for model in models\n",
    "]\n",
    "\n",
    "# create some workers\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "workers = alice, bob\n",
    "\n",
    "# Send Model Segments to starting locations\n",
    "model_locations = [alice, bob]\n",
    "\n",
    "for model, location in zip(models, model_locations):\n",
    "    model.send(location)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pixiedust": {
     "displayParams": {}
    }
   },
   "outputs": [],
   "source": [
    "def train(x, target, models, optimizers):\n",
    "    # Training Logic\n",
    "\n",
    "    #1) erase previous gradients (if they exist)\n",
    "    for opt in optimizers:\n",
    "        opt.zero_grad()\n",
    "\n",
    "    #2) make a prediction\n",
    "    a = models[0](x)\n",
    "\n",
    "    #3) break the computation graph link, and send the activation signal to the next model\n",
    "    remote_a = a.move(models[1].location, requires_grad=True)\n",
    "\n",
    "    #4) make prediction on next model using recieved signal\n",
    "    pred = models[1](remote_a)\n",
    "\n",
    "    #5) calculate how much we missed\n",
    "    criterion = nn.NLLLoss()\n",
    "    loss = criterion(pred, target)\n",
    "\n",
    "    #6) figure out which weights caused us to miss\n",
    "    loss.backward()\n",
    "\n",
    "    # 7) send gradient of the recieved activation signal to the model behind\n",
    "    # grad_a = remote_a.grad.copy().move(models[0].location)\n",
    "\n",
    "    # 8) backpropagate on bottom model given this gradient\n",
    "    # a.backward(grad_a)\n",
    "\n",
    "    #9) change the weights\n",
    "    for opt in optimizers:\n",
    "        opt.step()\n",
    "\n",
    "    #10) print our progress\n",
    "    return loss.detach().get()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 - Training loss: 0.5366032123565674\n",
      "Epoch 1 - Training loss: 0.25951144099235535\n",
      "Epoch 2 - Training loss: 0.19616961479187012\n",
      "Epoch 3 - Training loss: 0.1603524535894394\n",
      "Epoch 4 - Training loss: 0.13461314141750336\n",
      "Epoch 5 - Training loss: 0.11615928262472153\n",
      "Epoch 6 - Training loss: 0.10251112282276154\n",
      "Epoch 7 - Training loss: 0.0917714536190033\n",
      "Epoch 8 - Training loss: 0.081630177795887\n",
      "Epoch 9 - Training loss: 0.07489361613988876\n",
      "Epoch 10 - Training loss: 0.06842804700136185\n",
      "Epoch 11 - Training loss: 0.06335976719856262\n",
      "Epoch 12 - Training loss: 0.05796955153346062\n",
      "Epoch 13 - Training loss: 0.05360636115074158\n",
      "Epoch 14 - Training loss: 0.04987649992108345\n"
     ]
    }
   ],
   "source": [
    "for i in range(epochs):\n",
    "    running_loss = 0\n",
    "    for images, labels in trainloader:\n",
    "        images = images.send(alice)\n",
    "        images = images.view(images.shape[0], -1)\n",
    "        labels = labels.send(bob)\n",
    "        \n",
    "        loss = train(images, labels, models, optimizers)\n",
    "        running_loss += loss\n",
    "\n",
    "    else:\n",
    "        print(\"Epoch {} - Training loss: {}\".format(i, running_loss/len(trainloader)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
